package com.niit.rdd

import org.apache.spark.{SparkConf, SparkContext}

object SparkRDD_Transform_04 {

  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("spark")
    val sc = new SparkContext(sparkConf)
    sc.setLogLevel("ERROR")
    /*
        aggregateByKey：求分区间的平均值
     */
   val rdd1 = sc.makeRDD( List(
            ("a",1),("a",2),("b",3),
            ("b",4),("b",5),("a",6)
      ) ,2 )
    // (a,3)  (b,4)
    /*
    分区内：
        t == (0,0)
        第一个分区  (0,0):第一个0，表示的 值 a: 0+1+2  b:0+3
        第一个分区  (0,0):第二个0，表示的键出现的次数 a: 0+1+1  b:0+1
        第二个分区  (0,0):第一个0，表示的 值 a: 0+6  b:0+4+5
        第二个分区  (0,0):第二个0，表示的键出现的次数 a: 0+1  b:0+1+1
      结果 第一分区 a:(3,2)  b:(3,1)
          第二分区 a:(6,1)  b:(9,2)

      分区间：
          t1==表示第一个分区
          t2==表示第二个分区
          a: (3 + 6 ,2 + 1)  (9,3)
          b: (3 + 9,1 + 2)   (12,3)

     */
    val aggRdd = rdd1.aggregateByKey( (0,0) ) (
      (t,v)=>{
        (t._1+v,t._2+1)
      },
      (t1,t2)=>{
        (t1._1+t2._1,t1._2 + t2._2)
      }
    )
    //aggRdd==> [ (9,3)  ,   (12,3)]
    val resRdd = aggRdd.mapValues{
      case (num,cnt) =>{
        num/cnt
      }
    }
    resRdd.collect().foreach(println)
    /*
        cogroup:分组拼接 (conect group)
     */

    val rdd2 = sc.makeRDD( List( ("a",1),("b",5),("c",3),("g",7)
                      )  )
    val rdd3 = sc.makeRDD( List( ("a",4),("b",5),("c",6)
    ) )

    val coRdd = rdd2.cogroup(rdd3);
    coRdd.collect().foreach(println)



    sc.stop()
  }

}
